{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/cohere-ai/notebooks/blob/main/notebooks/llmu/Classify_Endpoint.ipynb\">\n",
    "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xYo_6bTr21nz"
   },
   "source": [
    "# The Classify Endpoint\n",
    "\n",
    "In text classification, developers and teams are increasingly using large language models (LLMs) to build AI-based classifiers. The alternative is building a system from scratch, which requires machine learning and engineering expertise as well as a large amount of labeled training data.\n",
    "\n",
    "With LLMs, instead of preparing thousands of training data points, you can get up and running with just a handful of examples per class, an approach called few-shot classification.\n",
    "\n",
    "In this notebook, you'll learn how to build a classifier with Cohere's Classify endpoint through few-shot learning. This notebook accompanies the [Classify endpoint lesson](https://docs.cohere.com/docs/classify-endpoint/) of LLM University."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1ys2CHEgurfe"
   },
   "source": [
    "## Setup\n",
    "\n",
    "We'll start by installing the tools we'll need and then importing them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "QdEURifRRUgy"
   },
   "outputs": [],
   "source": [
    "! pip install cohere -q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cohere\n",
    "from cohere import ClassifyExample"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Fill in your Cohere API key in the next cell. To do this, begin by [signing up to Cohere](https://os.cohere.ai/) (for free!) if you haven't yet. Then get your API key [here](https://dashboard.cohere.com/api-keys)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paste your API key here. Remember not to share it publicly\n",
    "co = cohere.Client(\"COHERE_API_KEY\") "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DtHJ02d7Rz8q"
   },
   "source": [
    "## Prepare Examples and Input\n",
    "\n",
    "A typical machine learning model requires many training examples to perform text classification, but with the Classify endpoint, you can get started with as few as 5 examples per class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "-lSi_UmQEfy_"
   },
   "outputs": [],
   "source": [
    "# Create the training examples for the classifier\n",
    "examples = [ClassifyExample(text=\"I’m so proud of you\", label=\"positive\"), \n",
    "            ClassifyExample(text=\"What a great time to be alive\", label=\"positive\"), \n",
    "            ClassifyExample(text=\"That’s awesome work\", label=\"positive\"), \n",
    "            ClassifyExample(text=\"The service was amazing\", label=\"positive\"), \n",
    "            ClassifyExample(text=\"I love my family\", label=\"positive\"), \n",
    "            ClassifyExample(text=\"They don't care about me\", label=\"negative\"), \n",
    "            ClassifyExample(text=\"I hate this place\", label=\"negative\"), \n",
    "            ClassifyExample(text=\"The most ridiculous thing I've ever heard\", label=\"negative\"), \n",
    "            ClassifyExample(text=\"I am really frustrated\", label=\"negative\"), \n",
    "            ClassifyExample(text=\"This is so unfair\", label=\"negative\"),\n",
    "            ClassifyExample(text=\"This made me think\", label=\"neutral\"), \n",
    "            ClassifyExample(text=\"The good old days\", label=\"neutral\"), \n",
    "            ClassifyExample(text=\"What's the difference\", label=\"neutral\"), \n",
    "            ClassifyExample(text=\"You can't ignore this\", label=\"neutral\"), \n",
    "            ClassifyExample(text=\"That's how I see it\", label=\"neutral\")]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "QpC1Z4xEEWs0"
   },
   "outputs": [],
   "source": [
    "# Enter the inputs to be classified\n",
    "inputs = [\"Hello, world! What a beautiful day\",\n",
    "          \"It was a great time with great people\",\n",
    "          \"Great place to work\",\n",
    "          \"That was a wonderful evening\",\n",
    "          \"Maybe this is why\",\n",
    "          \"Let's start again\",\n",
    "          \"That's how I see it\",\n",
    "          \"These are all facts\",\n",
    "          \"This is the worst thing\",\n",
    "          \"I cannot stand this any longer\",\n",
    "          \"This is really annoying\",\n",
    "          \"I am just plain fed up\"]"
   ]
  },
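  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate Predictions\n",
    "\n",
    "With the examples and inputs ready, we call the Classify endpoint and then print the predicted class and confidence score for each input."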
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "etTr200IRszm"
   },
   "outputs": [],
   "source": [
    "def classify_text(inputs, examples):\n",
    "    \"\"\"\n",
    "    Classifies a list of input texts given the examples.\n",
    "    Arguments:\n",
    "        inputs (list[str]): a list of input texts to be classified\n",
    "        examples (list[ClassifyExample]): a list of example texts and class labels\n",
    "    Returns:\n",
    "        classifications (list): each result contains the text, labels, and confidence values\n",
    "    \"\"\"\n",
    "    # Classify text by calling the Classify endpoint\n",
    "    response = co.classify(\n",
    "        model='embed-english-v2.0',\n",
    "        inputs=inputs,\n",
    "        examples=examples)\n",
    "\n",
    "    classifications = response.classifications\n",
    "\n",
    "    return classifications\n",
    "\n",
    "# Classify the inputs\n",
    "predictions = classify_text(inputs, examples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "kOdL3U0jRswU",
    "outputId": "93ea6111-ef75-4593-971c-b20b5dfb3d22"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input: Hello, world! What a beautiful day\n",
      "Prediction: positive\n",
      "Confidence: 0.84\n",
      "----------\n",
      "Input: It was a great time with great people\n",
      "Prediction: positive\n",
      "Confidence: 0.99\n",
      "----------\n",
      "Input: Great place to work\n",
      "Prediction: positive\n",
      "Confidence: 0.91\n",
      "----------\n",
      "Input: That was a wonderful evening\n",
      "Prediction: positive\n",
      "Confidence: 0.96\n",
      "----------\n",
      "Input: Maybe this is why\n",
      "Prediction: neutral\n",
      "Confidence: 0.70\n",
      "----------\n",
      "Input: Let's start again\n",
      "Prediction: neutral\n",
      "Confidence: 0.83\n",
      "----------\n",
      "Input: That's how I see it\n",
      "Prediction: neutral\n",
      "Confidence: 1.00\n",
      "----------\n",
      "Input: These are all facts\n",
      "Prediction: neutral\n",
      "Confidence: 0.78\n",
      "----------\n",
      "Input: This is the worst thing\n",
      "Prediction: negative\n",
      "Confidence: 0.93\n",
      "----------\n",
      "Input: I cannot stand this any longer\n",
      "Prediction: negative\n",
      "Confidence: 0.93\n",
      "----------\n",
      "Input: This is really annoying\n",
      "Prediction: negative\n",
      "Confidence: 0.99\n",
      "----------\n",
      "Input: I am just plain fed up\n",
      "Prediction: negative\n",
      "Confidence: 1.00\n",
      "----------\n"
     ]
    }
   ],
   "source": [
    "# Display the classification outcomes\n",
    "for inp, pred in zip(inputs, predictions):\n",
    "    # Each result carries the predicted label and its confidence score\n",
    "    class_pred = pred.prediction\n",
    "    class_conf = pred.confidence\n",
    "\n",
    "    print(f\"Input: {inp}\")\n",
    "    print(f\"Prediction: {class_pred}\")\n",
    "    print(f\"Confidence: {class_conf:.2f}\")\n",
    "    print(\"-\"*10)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "vscode": {
   "interpreter": {
    "hash": "1fb8019e3560b882083e525615cf48e713d3a7345a15eb723d805e91aa410aac"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
